import json
import argparse
import warnings
from codebleu import calc_codebleu
import os
import time
from utils_for_llm import *

# The prediction function below uses vLLM for batched, tensor-parallel generation.
from vllm import LLM, SamplingParams
from transformers import AutoTokenizer, AutoModelForCausalLM, TrainerCallback, TrainerState, TrainerControl

from vllm import SamplingParams
import re
from tqdm import tqdm

# Process-wide settings for vLLM workers and HF tokenizers.
os.environ.update({
    'VLLM_WORKER_MULTIPROC_METHOD': 'spawn',
    'TOKENIZERS_PARALLELISM': 'false',
    "CUDALIB_PATH": "/miniforge3/envs/vllm/bin",
})
# Point Triton at the CUDA binary tools that live in the CUDALIB_PATH directory.
os.environ.update({
    "TRITON_PTXAS_PATH": os.path.join(os.environ["CUDALIB_PATH"], "ptxas"),
    "TRITON_CUOBJDUMP_PATH": os.path.join(os.environ["CUDALIB_PATH"], "cuobjdump"),
    "TRITON_NVDISASM_PATH": os.path.join(os.environ["CUDALIB_PATH"], "nvdisasm"),
})
warnings.filterwarnings("ignore")


# Command-line interface. `choices` rejects an unknown --task at parse time,
# before the data pickles below are loaded.
parser = argparse.ArgumentParser()
parser.add_argument("--task", default="code_generation", type=str,
                    choices=["code_generation", "task_breakdown"],
                    help="Target to generate; selects the instruction formatter and target column.")
parser.add_argument("--load_path", default="", type=str,
                    help="Checkpoint to load with vLLM; empty string means the base Instruct model.")
parser.add_argument("--model_version", default=3.1, type=float,
                    help="Llama version used to build the base model path (e.g. 3.1).")
parser.add_argument("--model_size", default=8, type=float,
                    help="Model size in billions used to build the base model path (e.g. 8).")

parser.add_argument("--debug", action="store_true",
                    help="Debug flag.")

args = parser.parse_args()
task = args.task  # "code_generation" or "task_breakdown"


# Import explicitly instead of relying on `from utils_for_llm import *` to provide it.
import pickle

# NOTE(review): pickle.load can execute arbitrary code; these paths must point at
# trusted, locally produced files. Paths are relative to the current working directory.
with open('../data/statistics.pkl', 'rb') as fp:
    stat = pickle.load(fp)
with open('../data/identifier2python.pkl', 'rb') as fp:
    identifier2python = pickle.load(fp)



# Per-task settings: instruction formatter, target column and maximum sequence length.
max_seq_length = 8192
if task not in ("code_generation", "task_breakdown"):
    raise Exception(f'{task} is not defined.')

if task == "code_generation":
    format_instruction = format_instruction_with_code
    target_col = "code"
else:
    # task_breakdown: descriptions are short, so a much smaller length limit is used.
    format_instruction = format_instruction_without_code
    max_seq_length = 768
    target_col = "description"



def predict_on_validation_BATCH_vllm(llm, eval_dataset, external_data=False,
                                     splits=('val', 'ood'),
                                     response_key='Llama3.1-response',
                                     max_tokens=16192):
    """Greedy-decode every prompt in ``splits`` with vLLM and attach the responses.

    Each split of ``eval_dataset`` is converted (in place) from a DataFrame-like
    object to a list of record dicts. Every record's ``'prompt'`` is a list of
    message dicts whose ``'content'`` strings are concatenated verbatim to form
    the model input. The generated text is stored on the record under
    ``response_key``.

    Args:
        llm: a ``vllm.LLM`` instance.
        eval_dataset: mapping from split name to a DataFrame (anything with
            ``to_dict(orient='records')``) or an already-converted list of dicts.
        external_data: unused; kept for interface compatibility.
        splits: split names to process.
        response_key: record key that receives the generated text.
        max_tokens: generation length cap passed to ``SamplingParams``.

    Returns:
        ``eval_dataset`` itself, mutated in place with the converted splits.
    """
    tokenizer = llm.get_tokenizer()
    # Deterministic decoding. Stop on either the tokenizer EOS or Llama-3's end-of-turn token.
    sampling_params = SamplingParams(
        temperature=0.0,
        top_p=1.0,
        max_tokens=max_tokens,
        stop_token_ids=[tokenizer.eos_token_id, tokenizer.convert_tokens_to_ids("<|eot_id|>")],
    )

    for split in splits:
        records = eval_dataset[split]
        # Accept a DataFrame or a list of dicts already produced by a previous call.
        if hasattr(records, 'to_dict'):
            records = records.to_dict(orient='records')
        eval_dataset[split] = records

        eval_inputs = [''.join(message['content'] for message in sample['prompt'])
                       for sample in records]
        if not eval_inputs:
            continue

        outputs = llm.generate(eval_inputs, sampling_params)

        # vLLM returns outputs in the same order as the inputs.
        for record, output in tqdm(zip(records, outputs), total=len(records),
                                   desc="Processing", leave=False):
            record[response_key] = output.outputs[0].text

    return eval_dataset


if __name__ == "__main__":
    with open('./prompt_for_API_ICL.pkl', 'rb') as fp:
        prompt_for_API_ICL = pickle.load(fp)

    # argparse parses --model_version/--model_size as floats. Format them with `:g` so
    # `--model_size 70` yields "70B" rather than "70.0B" in the directory name.
    model_id = (
        "/Pretrained_Language_Models/"
        f"Meta-Llama-{args.model_version:g}-{args.model_size:g}B-Instruct"
    )

    tokenizer = AutoTokenizer.from_pretrained(model_id)
    tokenizer.pad_token = tokenizer.eos_token
    tokenizer.padding_side = "left"
    print('model_id:', model_id)
    print('load_path:', args.load_path)

    # NOTE(review): --debug does not change the tensor-parallel degree.
    tensor_parallel_size = 8
    if args.load_path == "":
        # Base Instruct weights are loaded explicitly in bfloat16.
        llm = LLM(model_id, tensor_parallel_size=tensor_parallel_size, dtype='bfloat16')
    else:
        # No dtype is passed for a custom checkpoint, so vLLM's default applies.
        llm = LLM(args.load_path, tensor_parallel_size=tensor_parallel_size)

    start = time.time()
    infer_result = predict_on_validation_BATCH_vllm(llm, prompt_for_API_ICL, external_data=True)
    timediff = time.time() - start

    minutes, seconds = divmod(timediff, 60)
    hours, minutes = divmod(minutes, 60)
    print(f"Time difference: {int(hours)} hours, {int(minutes)} minutes, {seconds:.2f} seconds")

    # The predictor writes responses into its input dict in place. Fall back to that
    # object if it returns nothing, so the dump is never just `null`.
    results = prompt_for_API_ICL if infer_result is None else infer_result

    dump_path = f'synthesized-{args.task}_result.json'
    with open(dump_path, 'w', encoding='utf-8') as fp:
        json.dump(results, fp, indent=4)
    print('dump_path:', dump_path)